//===-- Single-precision general sinhf/coshf functions --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef LLVM_LIBC_SRC___SUPPORT_MATH_SINHFCOSHF_UTILS_H
#define LLVM_LIBC_SRC___SUPPORT_MATH_SINHFCOSHF_UTILS_H

#include "exp10f_utils.h"
#include "src/__support/FPUtil/multiply_add.h"

namespace LIBC_NAMESPACE_DECL {

namespace math {

namespace sinhfcoshf_internal {

// Approximates sinh(x) (is_sinh == true) or cosh(x) (is_sinh == false) in
// double precision by evaluating exp(x) and exp(-x) simultaneously.
// To compute e^x, we perform the following range reduction: find hi, mid, lo
// such that:
//   x = (hi + mid) * log(2) + lo, in which
//     hi is an integer,
//     0 <= mid * 2^5 < 32 is an integer,
//     -2^(-6) <= lo * log2(e) <= 2^(-6).
// In particular,
//   hi + mid = round(x * log2(e) * 2^5) * 2^(-5).
// Then,
//   e^x = 2^(hi + mid) * e^lo = 2^hi * 2^mid * e^lo.
// 2^mid is stored in a lookup table of 32 elements.
// e^lo is computed using a degree-5 minimax polynomial generated by Sollya:
//   e^lo ~ P(lo) = 1 + lo + c2 * lo^2 + ... + c5 * lo^5
//        = (1 + c2*lo^2 + c4*lo^4) + lo * (1 + c3*lo^2 + c5*lo^4)
//        = P_even + lo * P_odd
// We compute 2^hi * 2^mid by simply adding hi to the exponent field of 2^mid.
// To compute e^(-x), notice that:
//   e^(-x) = 2^(-(hi + mid)) * e^(-lo)
//          ~ 2^(-(hi + mid)) * P(-lo)
//          = 2^(-(hi + mid)) * (P_even - lo * P_odd)
// So:
//   sinh(x) = (e^x - e^(-x)) / 2
//           ~ 0.5 * (2^(hi + mid) * (P_even + lo * P_odd) -
//                    2^(-(hi + mid)) * (P_even - lo * P_odd))
//           = 0.5 * (P_even * (2^(hi + mid) - 2^(-(hi + mid))) +
//                    lo * P_odd * (2^(hi + mid) + 2^(-(hi + mid))))
// And similarly:
//   cosh(x) = (e^x + e^(-x)) / 2
//           ~ 0.5 * (P_even * (2^(hi + mid) + 2^(-(hi + mid))) +
//                    lo * P_odd * (2^(hi + mid) - 2^(-(hi + mid))))
// The main point of these formulas is that the expensive part, evaluating the
// polynomials that approximate the lower parts of e^x and e^(-x), is shared
// and only done once.
//
// Template parameter:
//   is_sinh - selects the sinh formula when true, the cosh formula otherwise.
// Returns the approximation as a double (not yet rounded to float).
template <bool is_sinh> LIBC_INLINE double exp_pm_eval(float x) {
  double xd = static_cast<double>(x);

  // kd = k_p = round(x * log2(e) * 2^5)
  // k_m = -k_p = round(-x * log2(e) * 2^5)
  double kd;
  int k_p, k_m;

#ifdef LIBC_TARGET_CPU_HAS_NEAREST_INT
  kd = fputil::nearest_integer(ExpBase::LOG2_B * xd);
  k_p = static_cast<int>(kd);
  k_m = -k_p;
#else
  // Round half away from zero: add +0.5 (x >= 0) or -0.5 (x < 0), then let the
  // conversion to int truncate toward zero.
  constexpr double HALF_WAY[2] = {0.5, -0.5};

  k_p = static_cast<int>(
      fputil::multiply_add(xd, ExpBase::LOG2_B, HALF_WAY[x < 0.0f]));
  k_m = -k_p;
  kd = static_cast<double>(k_p);
#endif // LIBC_TARGET_CPU_HAS_NEAREST_INT

  // hi = floor(k * 2^(-5)) for k = k_p and k = k_m; this relies on the right
  // shift of a negative int being arithmetic.
  // exp_hi = hi shifted into the exponent field of a double.
  // hi is negative whenever k is (k_m is negative for every positive x), and
  // left-shifting a negative signed value is undefined behavior before C++20,
  // so the shift is performed on the two's complement bit pattern in uint64_t.
  uint64_t exp_hi_p = static_cast<uint64_t>(
                          static_cast<int64_t>(k_p >> ExpBase::MID_BITS))
                      << fputil::FPBits<double>::FRACTION_LEN;
  uint64_t exp_hi_m = static_cast<uint64_t>(
                          static_cast<int64_t>(k_m >> ExpBase::MID_BITS))
                      << fputil::FPBits<double>::FRACTION_LEN;
  // mh_p = 2^(hi + mid)
  // mh_m = 2^(-(hi + mid))
  // mh_bits_* = bit field of mh_*
  // Unsigned addition wraps modulo 2^64, so adding a negative exp_hi lowers
  // the exponent field exactly as the signed addition would.
  uint64_t mh_bits_p =
      static_cast<uint64_t>(ExpBase::EXP_2_MID[k_p & ExpBase::MID_MASK]) +
      exp_hi_p;
  uint64_t mh_bits_m =
      static_cast<uint64_t>(ExpBase::EXP_2_MID[k_m & ExpBase::MID_MASK]) +
      exp_hi_m;
  double mh_p = fputil::FPBits<double>(mh_bits_p).get_val();
  double mh_m = fputil::FPBits<double>(mh_bits_m).get_val();
  // mh_sum = 2^(hi + mid) + 2^(-(hi + mid))
  double mh_sum = mh_p + mh_m;
  // mh_diff = 2^(hi + mid) - 2^(-(hi + mid))
  double mh_diff = mh_p - mh_m;

  // dx = lo = x - (hi + mid) * log(2), using the _HI/_LO split of the
  // reduction constant in two FMAs to limit rounding error.
  double dx =
      fputil::multiply_add(kd, ExpBase::M_LOGB_2_LO,
                           fputil::multiply_add(kd, ExpBase::M_LOGB_2_HI, xd));
  double dx2 = dx * dx;

  // Both polynomials are pre-scaled by 1/2 to absorb the factor 0.5 in the
  // sinh/cosh formulas above.
  // p_even = P_even / 2 = (1 + COEFFS[0] * lo^2 + COEFFS[2] * lo^4) / 2
  double p_even = fputil::polyeval(dx2, 0.5, ExpBase::COEFFS[0] * 0.5,
                                   ExpBase::COEFFS[2] * 0.5);
  // p_odd = P_odd / 2 = (1 + COEFFS[1] * lo^2 + COEFFS[3] * lo^4) / 2
  double p_odd = fputil::polyeval(dx2, 0.5, ExpBase::COEFFS[1] * 0.5,
                                  ExpBase::COEFFS[3] * 0.5);

  double r;
  if constexpr (is_sinh) {
    // sinh(x) ~ lo * p_odd * mh_sum + p_even * mh_diff
    r = fputil::multiply_add(dx * mh_sum, p_odd, p_even * mh_diff);
  } else {
    // cosh(x) ~ lo * p_odd * mh_diff + p_even * mh_sum
    r = fputil::multiply_add(dx * mh_diff, p_odd, p_even * mh_sum);
  }
  return r;
}

} // namespace sinhfcoshf_internal

} // namespace math

} // namespace LIBC_NAMESPACE_DECL

#endif // LLVM_LIBC_SRC___SUPPORT_MATH_SINHFCOSHF_UTILS_H
